%% Program: empiricalApplication.m
% - Purpose: This program implements the empirical illustration in Section 5
% of ``Inference for Cluster Randomized Experiments with Non-ignorable
% Cluster Sizes'' by Bugni, Canay, Shaikh, and Tabord-Meehan (forthcoming
% JPEM 2024), henceforth BCST24.
% - Inputs: The data file is titled ``dataSection5.mat'', and should be
% placed in the folder from which this script is run (or on the MATLAB path).
% - Outputs: The file will produce all figures and tables in Section 5
% (Figures 2-3 and Table 7), saved as histogram_week.png,
% histogram_clinic.png, and Table7.csv.
% - Author: Federico A. Bugni
% - Date Created: 4/12/2024

% initialize file
clear % clear all variables from the workspace
close all % close all open figure windows (plain `close` only closes the current one)
clc % clear the command window
format bank % show results with 2 decimal places

% load data (provides the matrices data1, data2, and data3 described below);
load dataSection5.mat
% - description: this data file was created from the file ``maindata.dta''
% in the replication package of Celhay et al. (2019, AEJ: Applied),
% henceforth CGGV19. The file has 3 matrices: data1, data2, and data3,
% corresponding to the pre-intervention, intervention, and post-intervention
% periods, respectively (as determined by the variable period in CGGV19).
% Each matrix has 5 columns: the first column is the week of the first
% prenatal visit (variable weekpreg_1st in CGGV19), the second column
% indicates whether the week of the first prenatal visit is less than 13
% (week... in CGGV19), the third column is the treatment assignment
% (Z in CGGV19), the fourth column is the clinic number (id_clinic
% in CGGV19), and the fifth column is the clinic size, constructed
% by counting the number of patients attending the clinic within each
% dataset.

%% Figure 2 in BCST24: Histogram of the week of the first prenatal visit during the pre-intervention period.
% Draws a black histogram (bin width 2) of column 1 of the pre-intervention
% data, saves it as histogram_week.png, and prints summary statistics.
data = data1; % pre-intervention period
Y_data = data(:,1); % week of the first prenatal visit

fig = figure(); % new window; it is also the current figure
histogram(Y_data,'Normalization','count','BinWidth',2,'FaceColor','k')
set(fig,'Units','centimeters ','Position',[10 10 20*1.2 13*1.2]); % size of the saved image
print(fig,'-dpng','-r200','histogram_week.png'); % 200-dpi PNG

% Summary statistics: min, mean, max, standard deviation, share below 13.
summaryStats = [min(Y_data), mean(Y_data), max(Y_data), sqrt(var(Y_data)), mean(Y_data<13)];
disp('First prenatal visit in pre-intervention period:')
disp('          min,         average,      max,           std dev,      freq.<13')
disp(summaryStats)

%% Figure 3 in BCST24: Histogram of the patients per clinic having a first prenatal visit during the intervention period.
% Computes one size N(g) per clinic from the intervention-period data,
% draws a black histogram (bin width 5) of these sizes, saves it as
% histogram_clinic.png, and prints summary statistics.
data = data2; % load data of intervention period;
g_data = data(:,4); % column 4: clinic id (cluster number)
N_data = data(:,5); % column 5: sample size of cluster

% list of clusters
g_list = unique(g_data); % sorted unique list of clinic ids;
G = size(g_list,1); % total number of clusters;

% For each cluster, verify that the recorded cluster size is the same for
% all of its observations, and assign that size to the cluster.
N = nan(size(g_list));
for g = 1:G
    N_g = N_data(g_data == g_list(g)); % sizes recorded for the observations of cluster g
    if var(N_g)>0 % check non unique N(g)
        % report both the position in g_list and the actual clinic id
        disp(['Error: multiple N(g) for g = ',num2str(g),' (clinic id ',num2str(g_list(g)),')'])
    end
    N(g) = mean(N_g);
end

close % close the current figure (the Figure 2 window)
fig = figure(); % define figure name
histogram(N,'Normalization','count','BinWidth',5,'FaceColor','k') % histogram in black and white
set(fig,'Units','centimeters ','Position',[10 10 20*1.2 13*1.2]); % format histogram
print(fig,'-dpng','-r200','histogram_clinic.png'); % save histogram

% The commands below display some summary information:
disp('Cluster size in the intervention period:')
disp('          min,         average,      max,          std dev')
disp([min(N),mean(N),max(N),sqrt(var(N))])

%% Table 7 in BCST24: Estimation results based on data from Celhay et al. (2019)
% For each period (1: pre-intervention, 2: intervention, 3: post-intervention)
% and each outcome (data column 1 or 2), estimate
%   theta1: the ATE with every cluster weighted equally, and
%   theta2: the ATE with every cluster weighted by its size N(g),
% together with their standard errors and normal-based confidence intervals.
% Each (period, outcome, estimand) adds one 11-row column to `results`:
%   1: period, 2: outcome, 3: estimand (1 = theta1, 2 = theta2),
%   4: estimate, 5: s.e. (already divided by sqrt(G)),
%   6-8: CI lower bounds and 9-11: CI upper bounds, one per entry of alphas.
results = []; % initialize results table
alphas = [0.01;0.05;0.1]; % significance levels of the CIs
z_crit = norminv(1-alphas/2); % two-sided normal critical values, one per alpha
period_data = {data1, data2, data3}; % data of each period

for period = 1:3 % loop over periods
    data = period_data{period}; % load the data of this period
    for outcome = 1:2 % loop over outcome variables

        % define the variables
        Y_data = data(:,outcome); % outcome: column 1 or 2
        A_data = data(:,3); % column 3: treatment assignment
        g_data = data(:,4); % column 4: clinic id (cluster number)
        N_data = data(:,5); % column 5: sample size of cluster

        % list of clusters
        g_list = unique(g_data); % sorted unique list of clinic ids;
        G = size(g_list,1); % total number of clusters;

        % cluster-level variables: mean outcome, assignment, and size
        Ybar = nan(size(g_list));
        A = nan(size(g_list));
        N = nan(size(g_list));
        for g = 1:G % loop over clusters
            in_g = (g_data == g_list(g)); % observations belonging to cluster g
            Ybar(g) = mean(Y_data(in_g)); % compute Ybar(g)

            if var(A_data(in_g))>0 % check non unique A(g)
                disp(['Error: multiple A(g) for g = ',num2str(g),' (clinic id ',num2str(g_list(g)),')'])
            end
            A(g) = mean(A_data(in_g)); % compute A(g)

            if var(N_data(in_g))>0 % check non unique N(g)
                disp(['Error: multiple N(g) for g = ',num2str(g),' (clinic id ',num2str(g_list(g)),')'])
            end
            N(g) = mean(N_data(in_g)); % compute N(g)
        end

        % theta1: difference of the equally weighted treated/control means of Ybar
        mean1 = sum(A.*Ybar)/sum(A);
        mean0 = sum((1-A).*Ybar)/sum(1-A);
        theta1_hat = mean1 - mean0;

        % s.e. of theta1, from the within-arm variances of Ybar
        hatVar1 = sum(A.*Ybar.^2)/sum(A) - mean1.^2;
        hatVar0 = sum((1-A).*Ybar.^2)/sum(1-A) - mean0.^2;
        sigma1_hat = sqrt(hatVar1/mean(A) + hatVar0/mean(1-A)); % sigma_hat

        % CI for theta1
        LB_1 = theta1_hat - z_crit*sigma1_hat/sqrt(G);
        UB_1 = theta1_hat + z_crit*sigma1_hat/sqrt(G);

        % define column for table: col1 for theta1
        col1 = [period;outcome;1;theta1_hat;sigma1_hat/sqrt(G);LB_1;UB_1];

        % theta2: difference of the size-weighted treated/control means of Ybar
        mu1_hat = sum(A.*N.*Ybar)/sum(N.*A);
        mu0_hat = sum((1-A).*N.*Ybar)/sum(N.*(1-A));
        theta2_hat = mu1_hat - mu0_hat;

        % s.e. of theta2: observation-level residuals with respect to each
        % arm's size-weighted mean (computed once, not per observation),
        % summed within each cluster.
        ehat1 = Y_data - mu1_hat;
        ehat0 = Y_data - mu0_hat;

        sum_ehat1 = nan(size(g_list));
        sum_ehat0 = nan(size(g_list));
        for g = 1:G
            in_g = (g_data == g_list(g));
            sum_ehat1(g) = sum(ehat1(in_g));
            sum_ehat0(g) = sum(ehat0(in_g));
        end

        hat2Var1 = mean((sum_ehat1.^2).*A)/(mean(N.*A)^2);
        hat2Var0 = mean((sum_ehat0.^2).*(1-A))/(mean(N.*(1-A))^2);
        sigma2_hat = sqrt(hat2Var1+hat2Var0);

        % CI for theta2
        LB_2 = theta2_hat - z_crit*sigma2_hat/sqrt(G);
        UB_2 = theta2_hat + z_crit*sigma2_hat/sqrt(G);

        % define column for table: col2 for theta2
        col2 = [period;outcome;2;theta2_hat;sigma2_hat/sqrt(G);LB_2;UB_2];

        % collect columns for both theta's
        results = [results,[col1,col2]];

    end
end
% Significance stars: each CI (rows 6-8 = lower bounds, rows 9-11 = upper
% bounds) that excludes zero adds one star to its column.
lowerAboveZero = results(6:8,:) > 0;
upperBelowZero = results(9:11,:) < 0;
stars = sum(lowerAboveZero) + sum(upperBelowZero);
results1 = [results(1:5,:); stars]; % table of results without CI
results2 = [results1; results(6:end,:)]; % table of results with CI

% Display the 12-row results table (results2) on screen: rows 1-6 are the
% entries of Table 7; rows 7-12 are the CI bounds. The CI bounds are in
% the order of alphas = [0.01;0.05;0.1], i.e. 99%, 95%, then 90%.
disp('Table 7 with estimates, s.e., and CI. 12 rows are as follows:')
disp(' - Row 1: period: 1 for pre-intervention, 2 for intervention, and 3 for post-intervention')
disp(' - Row 2: outcome: 1 for weeks for first prenatal visit and 2 for week of visit<13')
disp(' - Row 3: estimand: 1 for theta1 and 2 for theta2')
disp(' - Row 4: estimate')
disp(' - Row 5: s.e., already divided by sqrt(G)')
disp(' - Row 6: stars: 0 for not sign, 1 for sign at 90%, 2 for sign at 95%, 3 for sign at 99%')
disp(' - Rows 7-9: lower bounds of the CIs at 99%, 95%, and 90%')
disp(' - Rows 10-12: upper bounds of the CIs at 99%, 95%, and 90%')
disp(results2)

% Write Table 7 (the 6 rows of results1) to Table7.csv, with the row labels
% written as the first column.
rowLabels = {'period','outcome','parameter','estimate','s.e.','stars'};
resultsTable = table(results1,'RowNames',rowLabels); % variable name "results1" comes from the input
writetable(resultsTable,'Table7.csv','WriteRowNames',true,'Delimiter',',','Encoding','UTF-8');